import os
import re
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from tkinterdnd2 import DND_FILES, TkinterDnD
import multiprocessing
import functools
from concurrent.futures import ThreadPoolExecutor
import threading

def _parse_srt_time(timestamp):
    """Split an SRT time such as ``01:02:03,456`` into ``(hours, minutes, seconds)``.

    The millisecond part after the comma is discarded.
    """
    hours, minutes, seconds = map(int, re.sub(r',\d+$', '', timestamp).split(':'))
    return hours, minutes, seconds


def _srt_time_label(timestamp):
    """Return a "第…分…秒" position label for *timestamp*; hours appear only when non-zero."""
    hours, minutes, seconds = _parse_srt_time(timestamp)
    if hours > 0:
        return f"第{hours}小时{minutes}分{seconds}秒"
    return f"第{minutes}分{seconds}秒"


def process_text(text, remove_timestamp=True, minute_interval=0):
    """
    Convert SRT subtitle text into plain text.

    Subtitle index lines and blank lines are always dropped. Timing lines
    ("H:M:S,ms --> H:M:S,ms") are dropped when *remove_timestamp* is true. A
    line consisting only of digits counts as an index line only when the next
    line is a timing line. Otherwise it is subtitle dialogue (e.g. "2020") and
    is kept. The output is framed by a separator line plus the start time of
    the first subtitle at the top, and a separator plus the start time of the
    last subtitle at the bottom.

    Args:
        text: Full contents of an SRT file.
        remove_timestamp: Drop timing lines from the output when true.
        minute_interval: When > 0, insert a separator and a "第N分钟" marker
            before the first subtitle whose start time falls into each new
            N-minute bucket. The 0-minute bucket is never marked.

    Returns:
        The processed lines joined with "\n" (no trailing newline), or "" when
        *text* holds nothing to keep.
    """
    separator = "*************************************************************************************"
    # Accepts 1-2 digit hours/minutes/seconds and 1-3 digit milliseconds.
    timestamp_pattern = re.compile(
        r'^\d{1,2}:\d{1,2}:\d{1,2},\d{1,3} --> \d{1,2}:\d{1,2}:\d{1,2},\d{1,3}$'
    )
    index_pattern = re.compile(r'^\d+$')

    lines = text.splitlines()
    stripped_lines = [line.strip() for line in lines]
    is_timing = [bool(timestamp_pattern.match(s)) for s in stripped_lines]
    # Start time (left side of "-->") of every subtitle, in file order.
    start_times = [
        s.split(' --> ')[0] for s, timing in zip(stripped_lines, is_timing) if timing
    ]

    # Line index -> minute value of the marker emitted just before that line.
    markers = {}
    if minute_interval > 0:
        last_marker = None
        for i, timing in enumerate(is_timing):
            if not timing:
                continue
            hours, minutes, _ = _parse_srt_time(stripped_lines[i].split(' --> ')[0])
            bucket = ((hours * 60 + minutes) // minute_interval) * minute_interval
            # Skip the 0-minute bucket and repeats of the last emitted bucket.
            if bucket > 0 and bucket != last_marker:
                markers[i] = bucket
                last_marker = bucket

    body = []
    for i, (line, stripped) in enumerate(zip(lines, stripped_lines)):
        if i in markers:
            body.append(separator)
            body.append(f"第{markers[i]}分钟")
        if not stripped:
            continue  # blank line
        if is_timing[i]:
            if remove_timestamp:
                continue
        elif index_pattern.match(stripped) and i + 1 < len(lines) and is_timing[i + 1]:
            # A bare number is a subtitle index only when a timing line follows it.
            continue
        body.append(line)

    result = []
    if start_times:
        result.extend([separator, _srt_time_label(start_times[0])])
    result.extend(body)
    if start_times:
        result.extend([separator, _srt_time_label(start_times[-1])])
    return '\n'.join(result)

def _normalize_srt_timestamp(timestamp_line):
    """
    Rewrite an SRT timing line into canonical ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` form.

    Each H/M/S component is zero-padded to two digits. The millisecond field
    is right-padded with zeros and truncated to three digits, so ``0:0:1,5``
    becomes ``00:00:01,500``. Parts without a comma are passed through as-is.
    """
    formatted = []
    for part in re.split(r'\s*-->\s*', timestamp_line):
        if ',' not in part:
            formatted.append(part)
            continue
        time_part, ms_part = part.split(',', 1)
        components = time_part.split(':')
        if len(components) == 3:
            time_part = ':'.join(component.zfill(2) for component in components)
        formatted.append(f"{time_part},{ms_part.ljust(3, '0')[:3]}")
    return ' --> '.join(formatted)


def split_srt_file(srt_file, lines_per_file, keep_original_timestamp=True):
    """
    Split an SRT file into parts of at most *lines_per_file* subtitle blocks.

    Parts are written as ``<base>_part1.srt``, ``<base>_part2.srt``, ... next
    to the source (UTF-8 with BOM). Each part's blocks are renumbered from 1.
    Blocks with fewer than two lines are skipped, but still count toward the
    chunk size. Blocks without text get a single-space text line.

    Args:
        srt_file: Path of the source .srt file (read as UTF-8, BOM allowed).
        lines_per_file: Number of subtitle blocks per output file.
        keep_original_timestamp: When true, timing lines are copied verbatim.
            When false, they are normalized to ``HH:MM:SS,mmm`` form.

    Returns:
        ``(True, [created file paths])`` on success, or ``(False, error message)``
        when the file cannot be read, contains no subtitle content, or
        splitting/writing fails.
    """
    print(f"正在拆分: {srt_file}，每个文件 {lines_per_file} 条字幕")

    try:
        with open(srt_file, 'r', encoding='utf-8-sig') as f:
            content = f.read()
    except Exception as e:
        print(f"读取文件 {srt_file} 出错: {e}")
        return False, str(e)

    try:
        stripped_content = content.strip()
        if not stripped_content:
            # Don't emit an empty "_part1.srt" for an empty source.
            message = f"文件 {srt_file} 中没有字幕内容"
            print(message)
            return False, message

        # Subtitle blocks are separated by one or more blank lines.
        subtitle_blocks = re.split(r'\n\s*\n', stripped_content)
        total_blocks = len(subtitle_blocks)
        num_files = (total_blocks + lines_per_file - 1) // lines_per_file  # ceiling division
        base_name = os.path.splitext(srt_file)[0]
        timing_pattern = re.compile(
            r'\d{1,2}:\d{1,2}:\d{1,2},\d{1,3}\s*-->\s*\d{1,2}:\d{1,2}:\d{1,2},\d{1,3}'
        )

        created_files = []
        for i in range(num_files):
            chunk = subtitle_blocks[i * lines_per_file:(i + 1) * lines_per_file]
            output_lines = []
            block_counter = 1  # numbering restarts in every part

            for block in chunk:
                lines = block.strip().split('\n')
                if len(lines) < 2:
                    continue

                timestamp_line = lines[1].strip()
                if not keep_original_timestamp and timing_pattern.match(timestamp_line):
                    timestamp_line = _normalize_srt_timestamp(timestamp_line)

                output_lines.append(str(block_counter))
                output_lines.append(timestamp_line)
                # A placeholder text line keeps the block well-formed when empty.
                output_lines.extend(lines[2:] if len(lines) > 2 else [" "])
                output_lines.append('')  # blank line terminating the block
                block_counter += 1

            part_content = '\n'.join(output_lines).strip() + '\n\n'

            output_file = f"{base_name}_part{i+1}.srt"
            with open(output_file, 'w', encoding='utf-8-sig') as f:
                f.write(part_content)
            created_files.append(output_file)
            print(f"已保存为: {output_file}")

        return True, created_files
    except Exception as e:
        print(f"拆分文件 {srt_file} 出错: {e}")
        return False, str(e)

def process_srt_file(srt_file, remove_timestamp=True, minute_interval=0):
    """
    Convert one .srt file into a .txt file with the same base name.

    The source is read as UTF-8 (a BOM is tolerated) and passed through
    ``process_text``. The result is written as UTF-8 next to the source.

    Returns:
        ``(True, txt_path)`` on success; ``(False, error_message)`` when the
        source cannot be read, or when processing/writing the output fails.
    """
    txt_file = os.path.splitext(srt_file)[0] + ".txt"
    print(f"正在处理: {srt_file}")

    # Stage 1: load the subtitle file.
    try:
        with open(srt_file, 'r', encoding='utf-8-sig') as source:
            subtitle_text = source.read()
    except Exception as exc:
        print(f"读取文件 {srt_file} 出错: {exc}")
        return False, str(exc)

    # Stage 2: transform and save; any failure here is reported as a write error.
    try:
        plain_text = process_text(subtitle_text, remove_timestamp, minute_interval)
        with open(txt_file, 'w', encoding='utf-8') as target:
            target.write(plain_text)
        print(f"已保存为: {txt_file}")
        return True, txt_file
    except Exception as exc:
        print(f"写入文件 {txt_file} 出错: {exc}")
        return False, str(exc)

# 添加一个非进程池版本的转换函数，用于直接转换单个文件
def convert_single_srt_to_txt(srt_file, remove_timestamp=True, minute_interval=0):
    """
    Convert one .srt file to .txt directly in the calling process (no worker pool).

    Any exception raised while reading, processing or writing is caught and
    reported as a single error message.

    Returns:
        ``(True, txt_path)`` on success, or ``(False, "转换文件 <path> 失败: <error>")``.
    """
    try:
        print(f"直接处理单个文件: {srt_file}")
        stem, _extension = os.path.splitext(srt_file)
        txt_file = "{}.txt".format(stem)

        with open(srt_file, 'r', encoding='utf-8-sig') as handle:
            raw_subtitles = handle.read()

        converted = process_text(raw_subtitles, remove_timestamp, minute_interval)

        with open(txt_file, 'w', encoding='utf-8') as handle:
            handle.write(converted)

        print(f"成功保存为: {txt_file}")
        return True, txt_file
    except Exception as err:
        failure = f"转换文件 {srt_file} 失败: {str(err)}"
        print(failure)
        return False, failure

def add_file_to_list(file_path):
    """
    Append *file_path* to the file list box unless it is already listed.

    A duplicate is not inserted again. Instead, the existing entry becomes the
    only selection and is scrolled into view.
    """
    listed = file_listbox.get(0, tk.END)
    if file_path in listed:
        print(f"文件已存在: {file_path}")
        position = listed.index(file_path)
        file_listbox.selection_clear(0, tk.END)
        file_listbox.selection_set(position)
        file_listbox.see(position)
        return
    file_listbox.insert(tk.END, file_path)

def update_progress(current, total):
    """
    Show *current* of *total* steps on the progress bar and its text label.

    A non-positive *total* is displayed as 0% instead of raising
    ZeroDivisionError. The window is repainted immediately via
    ``update_idletasks``.
    """
    percent = (current / total) * 100 if total > 0 else 0
    progress['value'] = percent
    progress_label.config(text=f"进度: {current}/{total} ({int(percent)}%)")
    root.update_idletasks()

def update_progress_text(text):
    """Replace the status line shown under the progress bar with *text* and repaint."""
    status_label.configure(text=text)
    root.update_idletasks()

def convert_files():
    """
    Convert every .srt file in the list box to a .txt file on a background thread.

    All options (timestamp removal, minute-marker interval, multiprocessing)
    are read from the Tk variables here, on the main thread, before the
    worker starts. With multiprocessing enabled and more than one file, files
    are processed by a pool of at most 4 processes; otherwise they are
    processed sequentially. A summary with success/failure counts is shown
    when the worker finishes, including when it stops on an unexpected error.
    """
    selected_files = file_listbox.get(0, tk.END)
    if not selected_files:
        messagebox.showinfo("提示", "请先拖动 .srt 文件到窗口中。")
        return

    remove_timestamp = remove_timestamp_var.get()
    try:
        minute_interval = int(minute_interval_var.get())
    except ValueError:
        # Non-numeric input disables the minute markers.
        minute_interval = 0
    # Snapshot the checkbox here: Tk variables are read on the main thread only.
    use_pool = use_multiprocessing.get() and len(selected_files) > 1
    total = len(selected_files)

    # Swap the placeholder for the progress widgets.
    placeholder_label.pack_forget()
    progress.pack(pady=5, fill=tk.X)
    progress_label.pack(pady=2)
    status_label.pack(pady=2)
    update_progress(0, total)

    results = []  # one (success, txt_path_or_error) tuple per processed file

    def process_complete_callback():
        # Files never reached because of an unexpected error count as failures.
        success_count = sum(1 for success, _ in results if success)
        messagebox.showinfo("提示", f"转换完成。成功: {success_count}, 失败: {total - success_count}")
        # Hide the progress widgets and restore the placeholder.
        progress.pack_forget()
        progress_label.pack_forget()
        status_label.pack_forget()
        placeholder_label.pack(pady=5)

    def process_in_thread():
        try:
            if use_pool:
                max_workers = min(multiprocessing.cpu_count(), 4)  # cap the process count
                # A partial of a module-level function stays picklable for the pool.
                worker = functools.partial(
                    process_srt_file,
                    remove_timestamp=remove_timestamp,
                    minute_interval=minute_interval,
                )
                with multiprocessing.Pool(processes=max_workers) as pool:
                    # imap yields results in input order as they complete, so the
                    # bar advances during the run (starmap returns only at the end).
                    for i, result in enumerate(pool.imap(worker, selected_files)):
                        results.append(result)
                        update_progress(i + 1, total)
                        update_progress_text(f"处理: {os.path.basename(selected_files[i])}")
            else:
                for i, file_path in enumerate(selected_files):
                    update_progress_text(f"处理: {os.path.basename(file_path)}")
                    results.append(process_srt_file(file_path, remove_timestamp, minute_interval))
                    update_progress(i + 1, total)
        except Exception as e:
            print(f"批量转换出错: {e}")
        finally:
            # Always hand control back to the main thread, or the UI stays stuck.
            root.after(0, process_complete_callback)

    # Run the work off the Tk main thread so the window stays responsive.
    threading.Thread(target=process_in_thread, daemon=True).start()

def split_files():
    """
    Split every .srt file in the list box into numbered parts on a background thread.

    Options are validated and read from the Tk variables here, on the main
    thread. Each source is cut into chunks of the configured number of
    subtitles by ``split_srt_file``, using a pool of at most 4 processes when
    multiprocessing is enabled and more than one file is queued. When the
    "convert after split" option is checked, the resulting part files are
    then converted one by one with ``convert_single_srt_to_txt``. A summary
    dialog is always shown at the end, even if the worker hits an unexpected
    error.
    """
    selected_files = file_listbox.get(0, tk.END)
    if not selected_files:
        messagebox.showinfo("提示", "请先拖动 .srt 文件到窗口中。")
        return

    try:
        lines_per_file = int(split_lines_var.get())
    except ValueError:
        messagebox.showerror("错误", "请输入有效的数字")
        return
    if lines_per_file <= 0:
        messagebox.showerror("错误", "每个文件字幕条数必须大于0")
        return

    keep_original_timestamp = keep_timestamp_var.get()
    convert_to_txt = convert_after_split_var.get()
    # Settings used for the optional txt conversion step.
    remove_timestamp = remove_timestamp_var.get()
    try:
        minute_interval = int(minute_interval_var.get())
    except ValueError:
        minute_interval = 0
    # Snapshot the checkbox here: Tk variables are read on the main thread only.
    use_pool = use_multiprocessing.get() and len(selected_files) > 1
    total = len(selected_files)

    # Swap the placeholder for the progress widgets.
    placeholder_label.pack_forget()
    progress.pack(pady=5, fill=tk.X)
    progress_label.pack(pady=2)
    status_label.pack(pady=2)
    update_progress(0, total)

    results = []          # (success, part_paths_or_error) per source file
    txt_results = []      # (success, txt_path_or_error) per part file
    all_split_files = []  # every part file created by a successful split

    def process_complete_callback():
        # Sources never reached because of an unexpected error count as failures.
        success_count = sum(1 for success, _ in results if success)
        split_summary = f"拆分完成。成功: {success_count}, 失败: {total - success_count}"
        if convert_to_txt and txt_results:
            txt_success_count = sum(1 for success, _ in txt_results if success)
            txt_failed = len(all_split_files) - txt_success_count
            messagebox.showinfo("提示", f"{split_summary}\n"
                              f"转换txt完成。成功: {txt_success_count}, 失败: {txt_failed}")
        else:
            messagebox.showinfo("提示", split_summary)
        # Hide the progress widgets and restore the placeholder.
        progress.pack_forget()
        progress_label.pack_forget()
        status_label.pack_forget()
        placeholder_label.pack(pady=5)

    def record_split(result):
        """Store one split result and collect its part files on success."""
        results.append(result)
        if result[0]:
            all_split_files.extend(result[1])

    def convert_parts_to_txt():
        """Convert the collected part files to .txt sequentially, in this thread."""
        part_total = len(all_split_files)
        print(f"需要转换的拆分文件数量: {part_total}")
        print(f"文件列表: {all_split_files}")
        update_progress_text(f"开始将拆分后的{part_total}个SRT文件转换为TXT...")
        update_progress(0, part_total)

        for i, file_path in enumerate(all_split_files):
            if os.path.isfile(file_path):
                update_progress_text(f"转换: {os.path.basename(file_path)}")
                # convert_single_srt_to_txt catches its own errors and reports them.
                result = convert_single_srt_to_txt(file_path, remove_timestamp, minute_interval)
                print(f"已转换 ({i+1}/{part_total}): {file_path} -> {result[1]}")
            else:
                result = (False, f"文件不存在或不是文件: {file_path}")
                print(result[1])
            txt_results.append(result)
            update_progress(i + 1, part_total)

    def process_in_thread():
        try:
            if use_pool:
                max_workers = min(multiprocessing.cpu_count(), 4)  # cap the process count
                # A partial of a module-level function stays picklable for the pool.
                worker = functools.partial(
                    split_srt_file,
                    lines_per_file=lines_per_file,
                    keep_original_timestamp=keep_original_timestamp,
                )
                with multiprocessing.Pool(processes=max_workers) as pool:
                    # imap yields results in input order as they complete, so the
                    # bar advances during the run (starmap returns only at the end).
                    for i, result in enumerate(pool.imap(worker, selected_files)):
                        record_split(result)
                        update_progress(i + 1, total)
                        update_progress_text(f"拆分: {os.path.basename(selected_files[i])}")
            else:
                for i, file_path in enumerate(selected_files):
                    update_progress_text(f"拆分: {os.path.basename(file_path)}")
                    record_split(split_srt_file(file_path, lines_per_file, keep_original_timestamp))
                    update_progress(i + 1, total)

            if convert_to_txt and all_split_files:
                convert_parts_to_txt()
        except Exception as e:
            print(f"批量拆分出错: {e}")
        finally:
            # Always hand control back to the main thread, or the UI stays stuck.
            root.after(0, process_complete_callback)

    # Run the work off the Tk main thread so the window stays responsive.
    threading.Thread(target=process_in_thread, daemon=True).start()

def on_drop(event):
    """
    Queue subtitle files dropped onto the list box or the drop area.

    Only paths ending in ``.srt`` (case-insensitive) are accepted. Paths that
    are already listed, or that appear twice in the same drop, are skipped.
    When anything was skipped, a summary is shown in the status area for 3
    seconds, unless a convert/split run has taken over that area by then.
    """
    # tkinterdnd2 delivers the dropped paths as a single Tcl list string.
    file_paths = root.tk.splitlist(event.data)
    existing_files = set(file_listbox.get(0, tk.END))
    added_count = 0
    skipped_count = 0

    for file_path in file_paths:
        if not file_path.lower().endswith('.srt'):
            continue  # not a subtitle file (.SRT is accepted too)
        if file_path in existing_files:
            skipped_count += 1
            continue
        add_file_to_list(file_path)
        existing_files.add(file_path)
        added_count += 1

    # Only report when duplicates were skipped.
    if skipped_count > 0:
        status_label.config(text=f"已添加: {added_count} 个文件，跳过重复: {skipped_count} 个文件")
        status_label.pack(pady=2)
        placeholder_label.pack_forget()

        def restore_placeholder():
            # A running convert/split owns the status area while its progress
            # bar is packed; don't hide its status text or re-show the placeholder.
            if progress.winfo_manager():
                return
            status_label.pack_forget()
            placeholder_label.pack(pady=5)

        # Restore the placeholder after 3 seconds.
        root.after(3000, restore_placeholder)

def clear_file_list():
    """Remove every entry from the file list box."""
    first, last = 0, tk.END
    file_listbox.delete(first, last)

if __name__ == '__main__':
    # Build the GUI. Every widget / Tk variable name assigned below is a
    # module-level global read by the handler functions defined earlier in
    # this file (e.g. file_listbox, progress, status_label, root).
    # NOTE(review): if this script is ever frozen into a Windows executable
    # (e.g. with PyInstaller), multiprocessing.freeze_support() is normally
    # required before the worker pool is used — confirm the packaging.

    # Main window; TkinterDnD.Tk lets the widgets below register as
    # drag-and-drop targets (drop_target_register / dnd_bind).
    root = TkinterDnD.Tk()
    root.title("SRT 字幕转换器")
    root.geometry("590x780")  # initial window size (width x height)
    
    # Main container frame with a small outer padding.
    main_frame = tk.Frame(root)
    main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
    
    # Alias only: widgets go directly into the main frame (no scrollable frame exists).
    scrollable_frame = main_frame
    
    # File list section: caption plus a list box holding the queued .srt paths.
    list_frame = tk.Frame(scrollable_frame)
    list_frame.pack(fill=tk.X, pady=5)
    
    list_label = tk.Label(list_frame, text="文件列表:")
    list_label.pack(anchor=tk.W, padx=5, pady=2)
    
    file_listbox = tk.Listbox(list_frame, width=75, height=8)  # 8 visible rows
    file_listbox.pack(fill=tk.X, padx=5)
    # The list box itself also accepts dropped files.
    file_listbox.drop_target_register(DND_FILES)
    file_listbox.dnd_bind('<<Drop>>', on_drop)

    # Dedicated drop area below the list.
    drop_label = tk.Label(scrollable_frame, text="将 .srt 文件拖动到此处", bg="lightgrey", width=75, height=2)
    drop_label.pack(pady=5)
    drop_label.drop_target_register(DND_FILES)
    drop_label.dnd_bind('<<Drop>>', on_drop)

    # Global settings section.
    global_frame = tk.LabelFrame(scrollable_frame, text="全局设置", padx=5, pady=2)
    global_frame.pack(fill=tk.X, pady=5)
    
    # Multiprocessing toggle; convert_files/split_files only use a process
    # pool when this is checked and more than one file is queued.
    use_multiprocessing = tk.BooleanVar(value=True)  # checked by default
    multiprocessing_checkbox = tk.Checkbutton(
        global_frame, 
        text="使用多进程处理（适合批量转换）", 
        variable=use_multiprocessing
    )
    multiprocessing_checkbox.pack(pady=2, anchor=tk.W)

    # Split settings section.
    split_settings_frame = tk.LabelFrame(scrollable_frame, text="拆分字幕设置", padx=5, pady=2)
    split_settings_frame.pack(fill=tk.X, pady=5)
    
    # Row holding the "subtitles per file" spinbox.
    split_line_frame = tk.Frame(split_settings_frame)
    split_line_frame.pack(pady=2, fill=tk.X)
    
    split_label = tk.Label(split_line_frame, text="拆分时每个文件字幕条数:")
    split_label.pack(side=tk.LEFT, padx=5)
    
    split_lines_var = tk.StringVar(value="1000")  # default: 1000 subtitles per part
    # Spinbox arrows step 100..10000; typed values are only validated as > 0
    # by split_files.
    split_lines_spinbox = ttk.Spinbox(
        split_line_frame, 
        from_=100, 
        to=10000, 
        increment=100,
        textvariable=split_lines_var,
        width=10
    )
    split_lines_spinbox.pack(side=tk.LEFT)
    
    # Keep timing lines verbatim when splitting (unchecked = normalize them).
    keep_timestamp_var = tk.BooleanVar(value=True)  # checked by default
    keep_timestamp_checkbox = tk.Checkbutton(
        split_settings_frame, 
        text="拆分按照原始时间戳", 
        variable=keep_timestamp_var
    )
    keep_timestamp_checkbox.pack(pady=2, anchor=tk.W)
    
    # Convert the generated part files to txt after splitting.
    convert_after_split_var = tk.BooleanVar(value=False)  # unchecked by default
    convert_after_split_checkbox = tk.Checkbutton(
        split_settings_frame, 
        text="拆分后批量转换为txt", 
        variable=convert_after_split_var
    )
    convert_after_split_checkbox.pack(pady=2, anchor=tk.W)
    
    # txt conversion settings section.
    convert_settings_frame = tk.LabelFrame(scrollable_frame, text="转换txt设置", padx=5, pady=2)
    convert_settings_frame.pack(fill=tk.X, pady=5)
    
    # Strip subtitle timing lines from the txt output.
    remove_timestamp_var = tk.BooleanVar(value=True)  # checked by default
    remove_timestamp_checkbox = tk.Checkbutton(
        convert_settings_frame, 
        text="去除字幕时间戳", 
        variable=remove_timestamp_var
    )
    remove_timestamp_checkbox.pack(pady=2, anchor=tk.W)
    
    # Minute-marker interval for the txt output.
    time_marker_frame = tk.Frame(convert_settings_frame)
    time_marker_frame.pack(pady=2, fill=tk.X)
    
    time_marker_label = tk.Label(time_marker_frame, text="转换后的txt每隔多少时间插入时间信息:")
    time_marker_label.pack(side=tk.LEFT, padx=5)
    
    minute_interval_var = tk.StringVar(value="10")  # default: a marker every 10 minutes; 0 disables markers
    minute_interval_spinbox = ttk.Spinbox(
        time_marker_frame, 
        from_=0, 
        to=120, 
        increment=5,
        textvariable=minute_interval_var,
        width=5
    )
    minute_interval_spinbox.pack(side=tk.LEFT)
    
    # Status section.
    status_frame = tk.LabelFrame(scrollable_frame, text="状态提示", padx=5, pady=5)
    status_frame.pack(fill=tk.X, pady=5)
    
    # Progress bar and status labels; created here but only packed while a run is active.
    progress = ttk.Progressbar(status_frame, orient="horizontal", length=300, mode="determinate")
    progress_label = tk.Label(status_frame, text="进度: 0/0 (0%)")
    status_label = tk.Label(status_frame, text="")
    
    # Placeholder shown while idle so the status frame keeps some height.
    placeholder_label = tk.Label(status_frame, text="准备就绪，等待操作...", height=1)
    placeholder_label.pack(pady=2)
    
    # Button bar.
    button_frame = tk.Frame(scrollable_frame)
    button_frame.pack(fill=tk.X, pady=5)

    # Three equally weighted grid columns keep the buttons in fixed positions.
    button_layout_frame = tk.Frame(button_frame)
    button_layout_frame.pack(fill=tk.X)
    button_layout_frame.columnconfigure(0, weight=1)
    button_layout_frame.columnconfigure(1, weight=1)
    button_layout_frame.columnconfigure(2, weight=1)

    # Left column (split button).
    left_button_frame = tk.Frame(button_layout_frame)
    left_button_frame.grid(row=0, column=0, padx=5, sticky="w")

    # Middle column (convert button).
    right_button_frame = tk.Frame(button_layout_frame)
    right_button_frame.grid(row=0, column=1, padx=5)

    # Right column (clear-list button).
    clear_button_frame = tk.Frame(button_layout_frame)
    clear_button_frame.grid(row=0, column=2, padx=5, sticky="e")

    # Split button -> split_files.
    split_button = tk.Button(
        left_button_frame,
        text="批量拆分字幕",
        command=split_files,
        width=15,
        height=2,
        bg="#2196F3",
        fg="white"
    )
    split_button.pack()

    # Convert button -> convert_files.
    convert_button = tk.Button(
        right_button_frame,
        text="srt字幕批量转换为txt",
        command=convert_files,
        width=18,
        height=2,
        bg="#4CAF50",
        fg="white"
    )
    convert_button.pack()

    # Clear-list button -> clear_file_list.
    clear_button = tk.Button(
        clear_button_frame,
        text="清空列表",
        command=clear_file_list,
        width=10,
        height=1,
        bg="#f0f0f0"
    )
    clear_button.pack()

    # Enter the Tk event loop (blocks until the window is closed).
    root.mainloop()